#include "statsxx/machine_learning/RBM.hpp" // this->calculate_p(), this->sample()

// jScience
#include "jScience/linalg.hpp"              // Vector<>
#include "jrandnum.hpp"                     // rand_num_normal_Mersenne_twister()


//
//
//
// Generate a visible vector by running a Gibbs chain that starts from a
// randomly drawn hidden vector.
//
//   niter    : number of visible-vector updates (negative phases) to run.
//              The first update is always performed, so any value < 1
//              behaves like 1.
//   sample_v : if true, each visible update is passed through this->sample();
//              if false, the output of this->calculate_p() is used directly.
//
// Returns the visible vector produced by the last update.
inline Vector<double> machine_learning::RBM::gen_sample(
                                                        const int  niter,
                                                        // -----
                                                        const bool sample_v
                                                        ) const
{
    // TODO: Just after training, could calculate some statistics about the visible and hidden states; this would allow us to set an appropriate cutoff.
    const double cutoff = 4.;

    //---------------------------------------------------------
    // PICK THE RANDOM-DRAW ARGUMENTS FOR THE HIDDEN TYPE
    //---------------------------------------------------------

    // The hidden type is the same for every unit, so the pair of arguments
    // handed to rand_num_normal_Mersenne_twister() is chosen once up front.
    //
    // NOTE(review): the generator is named "normal", yet the arguments below
    // read like range bounds (e.g. (-cutoff, cutoff)) -- confirm that the
    // generator's parameter meaning matches what is intended here.
    double draw_arg1   = 0.;
    double draw_arg2   = 0.;
    bool   known_htype = true;

    switch(this->htype)
    {
        case 0:
            draw_arg1 = 0.;
            draw_arg2 = 1.;
            break;
        case 1:
            draw_arg1 = -cutoff;
            draw_arg2 = cutoff;
            break;
        case 2: // ReLU shares its arguments with softplus ...
        case 3:
            draw_arg1 = 0.;
            draw_arg2 = cutoff;
            break;
        default:
            // Unrecognised type: no values are drawn, h keeps whatever
            // Vector<double>(nh) constructs.
            known_htype = false;
            break;
    }

    //---------------------------------------------------------
    // CREATE A RANDOM HIDDEN VECTOR
    //---------------------------------------------------------

    Vector<double> h(this->nh);

    if(known_htype)
    {
        for(int k = 0; k < this->nh; ++k)
        {
            h(k) = rand_num_normal_Mersenne_twister(draw_arg1, draw_arg2);
        }
    }

    // NOTE: h is not passed through this->sample() here, since it is already (and completely) randomly set.

    //---------------------------------------------------------
    // GIBBS CHAIN
    //---------------------------------------------------------

    // At least one visible update is always carried out.
    const int nsteps = (niter > 1) ? niter : 1;

    Vector<double> v;

    for(int step = 0; step < nsteps; ++step)
    {
        // ----- POSITIVE PHASE (skipped on the first pass: h is the random start) -----
        if(step > 0)
        {
            h = transpose(this->W)*v + this->b;

            Vector<double> ph = this->calculate_p(h, this->htype);

            h = this->sample(h, ph, this->htype);
        }

        // ----- NEGATIVE PHASE -----
        v = this->W*h + this->a;

        Vector<double> pv = this->calculate_p(v, this->vtype);

        if(!sample_v)
        {
            v = pv;
        }
        else
        {
            v = this->sample(v, pv, this->vtype);
        }
    }

    return v;
}

